package com.zbq.scbasegateway.routes.filter;

import io.github.bucket4j.Bandwidth;
import io.github.bucket4j.Bucket;
import io.github.bucket4j.Bucket4j;
import io.github.bucket4j.Refill;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.springframework.cloud.gateway.filter.GatewayFilter;
import org.springframework.cloud.gateway.filter.factory.AbstractGatewayFilterFactory;
import org.springframework.core.annotation.Order;
import org.springframework.http.HttpStatus;
import org.springframework.stereotype.Component;

import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.time.Duration;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/**
 * Gateway filter factory that rate-limits requests per client IP address using a Bucket4j
 * token bucket.
 *
 * <p>Every request consumes one token from the bucket belonging to its client IP. When that
 * bucket is empty the request is rejected with {@code 429 Too Many Requests}.
 *
 * <p>Buckets are kept in an in-memory map owned by each filter produced by
 * {@link #apply(Config)}. Routes with different {@link Config}s therefore never share or
 * override each other's buckets. This is single-node limiting only; a clustered gateway would
 * need a distributed store (e.g. Redis) instead.
 */
@Slf4j
@Order(-1000)
@Component
public class LimitByIPGatewayFilterFactory extends AbstractGatewayFilterFactory<LimitByIPGatewayFilterFactory.Config> {

    /** Rate-limit key used when the client address cannot be determined. */
    private static final String UNKNOWN_CLIENT = "unknown";

    public LimitByIPGatewayFilterFactory() {
        super(Config.class);
    }

    /**
     * Builds a per-IP rate-limiting filter for the given configuration.
     *
     * @param config bucket capacity and refill settings for this filter
     * @return a filter that consumes one token per request and answers 429 when none is left
     */
    @Override
    public GatewayFilter apply(Config config) {
        // One bucket map per filter instance. A static map keyed only by IP let the first filter
        // that saw an IP fix that IP's limits for every other filter/config using this factory.
        // NOTE(review): entries are never evicted, so the map grows with every distinct client IP.
        Map<String, Bucket> buckets = new ConcurrentHashMap<>();
        return (exchange, chain) -> {
            String ip = resolveClientIp(exchange.getRequest().getRemoteAddress());
            Bucket bucket = buckets.computeIfAbsent(ip, k -> createNewBucket(config));
            // Guarded so getAvailableTokens() is not evaluated when debug logging is off.
            if (log.isDebugEnabled()) {
                log.debug("IP:{} ,令牌通可用的Token数量:{} ", ip, bucket.getAvailableTokens());
            }
            if (bucket.tryConsume(1)) {
                return chain.filter(exchange);
            }
            // No token left for this IP: reject with 429 Too Many Requests.
            exchange.getResponse().setStatusCode(HttpStatus.TOO_MANY_REQUESTS);
            return exchange.getResponse().setComplete();
        };
    }

    /**
     * Returns the textual client address used as the rate-limit key.
     *
     * <p>NOTE(review): this is the direct peer address; behind a reverse proxy all clients will
     * share the proxy's key. Forwarded headers are deliberately not trusted here.
     *
     * @param remoteAddress the request's remote address, possibly {@code null}
     * @return the IP literal, the unresolved host string, or {@value #UNKNOWN_CLIENT}
     */
    private static String resolveClientIp(InetSocketAddress remoteAddress) {
        if (remoteAddress == null) {
            return UNKNOWN_CLIENT;
        }
        InetAddress address = remoteAddress.getAddress();
        // getAddress() is null for an unresolved socket address; fall back to its host string.
        return address != null ? address.getHostAddress() : remoteAddress.getHostString();
    }

    /**
     * Route-level settings for the per-IP token bucket.
     */
    @Data
    public static class Config {

        /**
         * Maximum number of tokens the bucket can hold.
         */
        int capacity;
        /**
         * Number of tokens replenished per {@code refillDuration}.
         */
        int refillTokens;
        /**
         * Period over which {@code refillTokens} tokens are replenished.
         */
        Duration refillDuration;
    }

    /**
     * Creates a new bucket with a single classic bandwidth limit built from {@code config}.
     *
     * @param config capacity and refill settings
     * @return a new Bucket4j bucket
     */
    private Bucket createNewBucket(Config config) {
        Refill refill = Refill.of(config.refillTokens, config.refillDuration);
        Bandwidth limit = Bandwidth.classic(config.capacity, refill);
        return Bucket4j.builder().addLimit(limit).build();
    }

}